#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import torch

# Similar to Weight Normalization: https://arxiv.org/abs/1602.07868, but there are differences.
# Also similar to Weight Standardization: https://arxiv.org/abs/1903.10520, but there are differences.
# In this implementation:
# (1) Small std values of the weights are clamped to eps, instead of blindly adding eps to the std
# (2) The whole tensor can be jointly standardized (optional), instead of each output channel separately
# (3) The standardized weights are saved into the parameter in eval pass so that the stored weights can work with regular convolution as well.
# (4) ONNX export does not export the standardization operations, but only the standardized weights with regular convolution
# Make sure that the model state_dict saving and the ONNX export are done in eval mode.
# Also make sure that your training does an eval pass at the end, so that the standardized weights are available in params.
class ConvWS2d(torch.nn.Conv2d):
    """torch.nn.Conv2d that standardizes its weight on every forward pass.

    Args:
        *args, **kwargs: forwarded unchanged to ``torch.nn.Conv2d``.
        per_channel: if True (default), each output channel's weights are
            standardized separately; if False, the whole weight tensor is
            standardized jointly.

    In eval mode the standardized weight is also written back into
    ``self.weight`` so that the stored parameters work with a plain Conv2d.
    """
    def __init__(self, *args, per_channel=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.per_channel = per_channel

    def forward(self, input):
        # honour the constructor option (it was previously stored but ignored)
        weight = standardize_weight(self.weight, per_channel=self.per_channel)
        if not self.training:
            # store the standardized weight in the parameter
            # so that the stored weights can work with regular convolution as well.
            # note: this storing is done only in eval mode to save time
            # make sure that an eval mode run is done before saving the weights
            self.weight.data.copy_(weight.data)
            # detach the weight in eval mode to make sure that
            # the onnx graph does not have the above operations
            weight = weight.data
        #
        return self._conv_with_weight(input, weight)

    def _conv_with_weight(self, input, weight):
        """Run the parent Conv2d convolution with an explicit ``weight``.

        The parent's private helper changed across PyTorch releases:
        ``conv2d_forward(input, weight)`` (< 1.5), ``_conv_forward(input, weight)``
        (1.5-1.7) and ``_conv_forward(input, weight, bias)`` (>= 1.8). Delegating
        to it keeps padding_mode handling identical to torch.nn.Conv2d.
        """
        conv_forward = getattr(self, '_conv_forward', None)
        if conv_forward is None:
            return self.conv2d_forward(input, weight)
        import inspect
        # bound-method signature excludes self: 3 params means the bias-taking variant
        if len(inspect.signature(conv_forward).parameters) >= 3:
            return conv_forward(input, weight, self.bias)
        return conv_forward(input, weight)


def standardize_weight(weight, per_channel=True, eps=1e-5):
    """Return a zero-mean, unit-std version of ``weight``.

    Args:
        weight: weight tensor of rank >= 1. When ``per_channel`` is True its
            first dimension is treated as the output-channel dimension
            (e.g. a Conv2d or Linear weight).
        per_channel: if True, mean/std are computed separately for each slice
            ``weight[i]``; otherwise over the whole tensor.
        eps: lower bound the std is clamped to (instead of adding eps to it),
            so near-constant weights do not blow up.

    Returns:
        Tensor of the same shape as ``weight``: ``(weight - mean) / max(std, eps)``,
        where std is torch's default unbiased estimate. A group with a single
        element has no unbiased std (torch returns NaN); it is treated as std 0,
        which yields 0 instead of propagating NaN.
    """
    if per_channel:
        out_channels = weight.size(0)
        # one statistic per output channel, broadcastable against weight of any rank
        stat_shape = (out_channels,) + (1,) * (weight.dim() - 1)
        # reshape (not view) so non-contiguous weights are accepted
        flat = weight.reshape(out_channels, -1)
        weight_mean = flat.mean(dim=1).reshape(stat_shape)
        if flat.size(1) > 1:
            weight_std = flat.std(dim=1).reshape(stat_shape)
        else:
            weight_std = torch.zeros_like(weight_mean)
    else:
        weight_mean = weight.mean()
        if weight.numel() > 1:
            weight_std = weight.std()
        else:
            weight_std = torch.zeros_like(weight_mean)
    #
    weight_std = torch.clamp(weight_std, min=eps)
    weight = (weight - weight_mean) / weight_std
    return weight